# Define all the service endpoint handlers here.
import json
import os
import re

from flask import Response, request, send_file
from google.protobuf.json_format import MessageToJson, ParseDict
from querystring_parser import parser

from mlflow.entities.metric import Metric
from mlflow.entities.param import Param
from mlflow.entities.run_tag import RunTag
from mlflow.protos import databricks_pb2
from mlflow.protos.service_pb2 import CreateExperiment, MlflowService, GetExperiment, \
    GetRun, SearchRuns, ListArtifacts, GetArtifact, GetMetricHistory, CreateRun, \
    UpdateRun, LogMetric, LogParam, ListExperiments, GetMetric, GetParam
from mlflow.store.artifact_repo import ArtifactRepository

# Backing store that the handlers below call into (e.g. `store.create_experiment`,
# `store.get_run`). It is None at import time.
# Configured in `mlflow.cli`.
store = None


def _get_request_message(request_message, from_get=False):
    """
    Populate a protobuf request message from the current Flask request.

    :param request_message: Protobuf message instance to fill in.
    :param from_get: If True and the request has a non-empty query string, the fields are
                     read from the query string. In every other case they are read from
                     the JSON request body.
    :return: The same ``request_message`` instance, populated through ``ParseDict``.
    """
    if from_get and request.query_string:
        # This is a hack to make arrays of length 1 work with the parser.
        # One would expect experiment_ids%5B%5D=0 to be parsed to {experiment_ids: [0]},
        # but it is parsed to {experiment_ids: 0}. The explicitly indexed form
        # experiment_ids%5B0%5D=0 is parsed to the right result, so rewrite to that form.
        query_string = re.sub('%5B%5D', '%5B0%5D', request.query_string.decode("utf-8"))
        request_dict = parser.parse(query_string, normalized=True)
        ParseDict(request_dict, request_message)
        return request_message
    # `get_json` already returns the decoded JSON document (or None, because of
    # `silent=True`, when the body is missing or invalid), so it must not be fed through
    # `json.loads` a second time.
    request_json = request.get_json(force=True, silent=True)
    # If request doesn't have json body then assume it's empty.
    if request_json is None:
        request_json = {}
    ParseDict(request_json, request_message)
    return request_message


def get_handler(request_class):
    """
    Look up the handler function registered for a request type.

    :param request_class: The type of protobuf message
    :return: The function stored in ``HANDLERS`` for ``request_class``, or
             ``_not_implemented`` when nothing is registered for it.
    """
    try:
        return HANDLERS[request_class]
    except KeyError:
        return _not_implemented


def _not_implemented():
    """
    Fallback handler for request types without a registered handler.

    :return: An empty Flask response with HTTP status code 404.
    """
    return Response(status=404)


def _create_experiment():
    """
    Handle a CreateExperiment request.

    Reads the experiment name from the request body, creates the experiment in the store
    and returns the new experiment id.

    :return: A JSON Flask response holding a ``CreateExperiment.Response`` message.
    """
    request_message = _get_request_message(CreateExperiment())
    experiment_id = store.create_experiment(request_message.name)
    response_message = CreateExperiment.Response()
    response_message.experiment_id = experiment_id
    response = Response(mimetype='application/json')
    # Keep the proto field names (``experiment_id`` rather than ``experimentId``) so this
    # endpoint serializes its response the same way as every other handler in this module.
    response.set_data(MessageToJson(response_message, preserving_proto_field_name=True))
    return response


def _get_experiment():
    """
    Handle a GetExperiment request.

    Looks up the experiment named by ``experiment_id`` in the query parameters and returns
    it together with the run infos the store lists for that experiment.

    :return: A JSON Flask response holding a ``GetExperiment.Response`` message.
    """
    req = _get_request_message(GetExperiment(), from_get=True)
    result = GetExperiment.Response()

    experiment_proto = store.get_experiment(req.experiment_id).to_proto()
    result.experiment.MergeFrom(experiment_proto)

    run_info_protos = [info.to_proto() for info in store.list_run_infos(req.experiment_id)]
    result.runs.extend(run_info_protos)

    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _create_run():
    """
    Handle a CreateRun request.

    Converts the tags of the request into ``RunTag`` entities, creates the run in the
    store from the request fields and returns the created run.

    :return: A JSON Flask response holding a ``CreateRun.Response`` message.
    """
    req = _get_request_message(CreateRun())
    run_tags = [RunTag(tag.key, tag.value) for tag in req.tags]

    created_run = store.create_run(experiment_id=req.experiment_id,
                                   user_id=req.user_id,
                                   run_name=req.run_name,
                                   source_type=req.source_type,
                                   source_name=req.source_name,
                                   entry_point_name=req.entry_point_name,
                                   start_time=req.start_time,
                                   source_version=req.source_version,
                                   tags=run_tags)

    result = CreateRun.Response()
    result.run.MergeFrom(created_run.to_proto())
    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _update_run():
    """
    Handle an UpdateRun request.

    Passes the run uuid, status and end time of the request to the store and returns the
    updated run info.

    :return: A JSON Flask response holding an ``UpdateRun.Response`` message.
    """
    req = _get_request_message(UpdateRun())
    run_info = store.update_run_info(req.run_uuid, req.status, req.end_time)
    result = UpdateRun.Response(run_info=run_info.to_proto())
    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _log_metric():
    """
    Handle a LogMetric request.

    Builds a ``Metric`` from the key, value and timestamp of the request and logs it in
    the store against the run named by ``run_uuid``.

    :return: A JSON Flask response holding an empty ``LogMetric.Response`` message.
    """
    req = _get_request_message(LogMetric())
    store.log_metric(req.run_uuid, Metric(req.key, req.value, req.timestamp))
    payload = MessageToJson(LogMetric.Response(), preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _log_param():
    """
    Handle a LogParam request.

    Builds a ``Param`` from the key and value of the request and logs it in the store
    against the run named by ``run_uuid``.

    :return: A JSON Flask response holding an empty ``LogParam.Response`` message.
    """
    req = _get_request_message(LogParam())
    store.log_param(req.run_uuid, Param(req.key, req.value))
    payload = MessageToJson(LogParam.Response(), preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _get_run():
    """
    Handle a GetRun request.

    Fetches the run named by ``run_uuid`` in the query parameters from the store.

    :return: A JSON Flask response holding a ``GetRun.Response`` message.
    """
    req = _get_request_message(GetRun(), from_get=True)
    result = GetRun.Response()
    run_proto = store.get_run(req.run_uuid).to_proto()
    result.run.MergeFrom(run_proto)
    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _search_runs():
    """
    Handle a SearchRuns request.

    Forwards the ``experiment_ids`` and ``anded_expressions`` of the request to the store
    and returns the runs it finds.

    :return: A JSON Flask response holding a ``SearchRuns.Response`` message.
    """
    req = _get_request_message(SearchRuns(), from_get=True)
    result = SearchRuns.Response()
    matching_runs = store.search_runs(req.experiment_ids, req.anded_expressions)
    run_protos = [run.to_proto() for run in matching_runs]
    result.runs.extend(run_protos)
    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _list_artifacts():
    """
    Handle a ListArtifacts request.

    Lists the artifacts of the run named by ``run_uuid``. When the request sets ``path``
    the listing is done for that path, otherwise ``None`` is passed to the repository.

    :return: A JSON Flask response holding a ``ListArtifacts.Response`` message with the
             listed files and the ``artifact_uri`` of the run's artifact repository.
    """
    request_message = _get_request_message(ListArtifacts(), from_get=True)
    response_message = ListArtifacts.Response()
    path = request_message.path if request_message.HasField('path') else None
    run = store.get_run(request_message.run_uuid)
    # Build the repository once and reuse it for both the listing and the root URI rather
    # than constructing a second, identical repository for the same run.
    artifact_repo = _get_artifact_repo(run)
    artifact_entities = artifact_repo.list_artifacts(path)
    response_message.files.extend([a.to_proto() for a in artifact_entities])
    response_message.root_uri = artifact_repo.artifact_uri
    response = Response(mimetype='application/json')
    response.set_data(MessageToJson(response_message, preserving_proto_field_name=True))
    return response


# Artifact file name suffixes (written without the leading dot) that `_get_artifact`
# serves with a 'text/plain' mimetype.
_TEXT_EXTENSIONS = ['txt', 'yaml', 'json', 'js', 'py', 'csv', 'MLmodel', 'MLproject']


def _get_artifact():
    """
    Handle a GetArtifact request.

    Downloads the artifact at ``path`` for the run named by ``run_uuid`` through the run's
    artifact repository and sends the downloaded file back.

    :return: The Flask ``send_file`` response for the file. Files recognised through
             ``_TEXT_EXTENSIONS`` are sent with a 'text/plain' mimetype; for all other
             files the mimetype is left to ``send_file``.
    """
    request_message = _get_request_message(GetArtifact(), from_get=True)
    run = store.get_run(request_message.run_uuid)
    filename = os.path.abspath(_get_artifact_repo(run).download_artifact(request_message.path))
    base_name = os.path.basename(filename)
    extension = os.path.splitext(base_name)[-1].replace(".", "")
    # `splitext` yields an empty extension for files without one (e.g. "MLmodel" and
    # "MLproject"), so those entries of _TEXT_EXTENSIONS could never match on the
    # extension alone. Fall back to the whole file name in that case.
    if (extension or base_name) in _TEXT_EXTENSIONS:
        return send_file(filename, mimetype='text/plain')
    return send_file(filename)


def _get_metric_history():
    """
    Handle a GetMetricHistory request.

    Asks the store for the history of the metric ``metric_key`` of the run ``run_uuid``.

    :return: A JSON Flask response holding a ``GetMetricHistory.Response`` message.
    """
    req = _get_request_message(GetMetricHistory(), from_get=True)
    result = GetMetricHistory.Response()
    history = store.get_metric_history(req.run_uuid, req.metric_key)
    metric_protos = [metric.to_proto() for metric in history]
    result.metrics.extend(metric_protos)
    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _get_metric():
    """
    Handle a GetMetric request.

    Asks the store for the metric ``metric_key`` of the run ``run_uuid``.

    :return: A JSON Flask response holding a ``GetMetric.Response`` message.
    """
    req = _get_request_message(GetMetric(), from_get=True)
    result = GetMetric.Response()
    metric_proto = store.get_metric(req.run_uuid, req.metric_key).to_proto()
    result.metric.MergeFrom(metric_proto)
    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _get_param():
    """
    Handle a GetParam request.

    Asks the store for the parameter ``param_name`` of the run ``run_uuid``.

    :return: A JSON Flask response holding a ``GetParam.Response`` message.
    """
    req = _get_request_message(GetParam(), from_get=True)
    result = GetParam.Response()
    param_proto = store.get_param(req.run_uuid, req.param_name).to_proto()
    result.parameter.MergeFrom(param_proto)
    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _list_experiments():
    """
    Handle a ListExperiments request.

    Returns every experiment the store lists. The request itself is not parsed.

    :return: A JSON Flask response holding a ``ListExperiments.Response`` message.
    """
    result = ListExperiments.Response()
    experiment_protos = [experiment.to_proto() for experiment in store.list_experiments()]
    result.experiments.extend(experiment_protos)
    payload = MessageToJson(result, preserving_proto_field_name=True)
    response = Response(mimetype='application/json')
    response.set_data(payload)
    return response


def _get_artifact_repo(run):
    """
    Build the artifact repository for a run.

    :param run: Run whose ``info`` supplies ``artifact_uri`` (and, for the fallback,
                ``experiment_id`` and ``run_uuid``).
    :return: The ``ArtifactRepository`` created from the run's artifact URI, or from a
             path under the store's root directory when the run has no artifact URI.
    """
    run_info = run.info
    artifact_uri = run_info.artifact_uri
    if not artifact_uri:
        # TODO(aaron) Remove this once everyone locally only has runs from after
        # the introduction of "artifact_uri".
        artifact_uri = os.path.join(store.root_directory, str(run_info.experiment_id),
                                    run_info.run_uuid, "artifacts")
    return ArtifactRepository.from_artifact_uri(artifact_uri)


def _get_paths(base_path):
    """
    A service endpoints base path is typically something like /preview/mlflow/experiment.
    We should register paths like /api/2.0/preview/mlflow/experiment and
    /ajax-api/2.0/preview/mlflow/experiment in the Flask router.
    """
    return ['/api/2.0{}'.format(base_path), '/ajax-api/2.0{}'.format(base_path)]


def get_endpoints():
    """
    Collect the routes for every method of ``MlflowService``.

    Each endpoint declared in a method's ``databricks_pb2.rpc`` option is expanded with
    ``_get_paths`` and paired with the handler for the method's request class.

    :return: List of tuples (path, handler, methods)
    """
    routes = []
    for method_descriptor in MlflowService.DESCRIPTOR.methods:
        rpc_options = method_descriptor.GetOptions().Extensions[databricks_pb2.rpc]
        for endpoint in rpc_options.endpoints:
            routes.extend(
                (path,
                 get_handler(MlflowService().GetRequestClass(method_descriptor)),
                 [endpoint.method])
                for path in _get_paths(endpoint.path))
    return routes


# Dispatch table from protobuf request class to the function that serves it.
# `get_handler` reads this table and falls back to `_not_implemented` for request
# classes that have no entry here.
HANDLERS = {
    CreateExperiment: _create_experiment,
    GetExperiment: _get_experiment,
    CreateRun: _create_run,
    UpdateRun: _update_run,
    LogParam: _log_param,
    LogMetric: _log_metric,
    GetRun: _get_run,
    SearchRuns: _search_runs,
    ListArtifacts: _list_artifacts,
    GetArtifact: _get_artifact,
    GetMetricHistory: _get_metric_history,
    ListExperiments: _list_experiments,
    GetParam: _get_param,
    GetMetric: _get_metric,
}
